Multi-layer Feed-forward NN with ReLU activation

Feed-forward NN with 256 ReLU neurons for each of 2 hidden layers.


In [1]:
import tensorflow as tf
import numpy as np

In [18]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [7]:
# each layer consists of a f(xW + b) where f is a ReLU activation function
# For ReLU units, a study published in 2015 by He et al. demonstrates that the variance
# of weights in a network should be 2/n_in , where n_in is the number inputs coming into in the neuron.
def layer(input, weight_shape, bias_shape):
    weight_stddev = (2.0 / weight_shape[0])**0.5 # variance=2/n_in as explained above
    w_init = tf.random_normal_initializer(stddev=weight_stddev)
    bias_init = tf.constant_initializer(value=0)
    W = tf.get_variable("W", weight_shape, initializer=w_init)
    b = tf.get_variable("b", bias_shape, initializer=bias_init)
    return tf.nn.relu(tf.matmul(input, W) + b)

In [8]:
# inference is the way we pass our input x through the nn of 2 hidden layers and 1 output layer
def inference(x):
    with tf.variable_scope("hidden_1"):
        hidden_1 = layer(x, [784, 256], [256])
    with tf.variable_scope("hidden_2"):
        hidden_2 = layer(hidden_1, [256, 256], [256])
    with tf.variable_scope("output"):
        output = layer(hidden_2, [256, 10], [10])
    return output

In [14]:
# softmax moved form inference to loss function for performance improvement
def loss(output, y):
    xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=y)
    loss = tf.reduce_mean(xentropy)
    return loss

In [15]:
def training(cost, global_step):
    tf.summary.tensor_summary("cost", cost)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.minimize(cost, global_step=global_step)
    return train_op

In [19]:
def evaluate(output, y):
    # compare indices of predicted class, if equal (correct classification) set 1 otherwise 0
    correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return accuracy

In [20]:
# Parameters
learning_rate = 0.01
training_epochs = 100
batch_size = 100
display_step = 1

In [ ]:
from tqdm import tqdm

# program flow
with tf.Graph().as_default():
    # mnist data image of shape 28*28=784
    x = tf.placeholder("float", [None, 784])
    # 0-9 digits recognition => 10 classes
    y = tf.placeholder("float", [None, 10])
    output = inference(x)
    cost = loss(output, y)
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = training(cost, global_step)
    eval_op = evaluate(output, y)

    # tf.merge_all_summaries in order to collect all summary statistics
    # use a tf.train.SummaryWriter to write the log to disk.
    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver()
    sess = tf.Session()
    # write to tensorboard graph api
    summary_writer = tf.summary.FileWriter(
        "logistic_logs/", graph=sess.graph)
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # training cycle
    for epoch in tqdm(range(training_epochs)):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples / batch_size)

        # Loop over all batches
        for i in range(total_batch):
            mbatch_x, mbatch_y = mnist.train.next_batch(batch_size)

            # Fit training using batch data
            feed_dict = {x: mbatch_x, y: mbatch_y}
            sess.run(train_op, feed_dict=feed_dict)

            # Compute average loss
            minibatch_cost = sess.run(cost, feed_dict=feed_dict)
            avg_cost += minibatch_cost / total_batch
        # Display logs per epoch step
        if epoch % display_step == 0:
            val_feed_dict = {
                x: mnist.validation.images,
                y: mnist.validation.labels
            }
            accuracy = sess.run(eval_op, feed_dict=val_feed_dict)
            print("Validation Error in epoch %s: %.11f" % (epoch, 1 - accuracy))
            summary_str = sess.run(summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, sess.run(global_step))
            saver.save(
                sess,
                "logistic_logs/model-checkpoint",
                global_step=global_step)

    test_feed_dict = {x: mnist.test.images, y: mnist.test.labels}
    accuracy = sess.run(eval_op, feed_dict=test_feed_dict)
    print("Test Accuracy:", accuracy)


  1%|          | 1/100 [00:02<04:49,  2.93s/it]
Validation Error in epoch 0: 0.52779999375
Validation Error in epoch 1: 0.43000000715
  3%|▎         | 3/100 [00:08<04:42,  2.91s/it]
Validation Error in epoch 2: 0.38919997215
  4%|▍         | 4/100 [00:11<04:39,  2.91s/it]
Validation Error in epoch 3: 0.29780000448
  5%|▌         | 5/100 [00:14<04:36,  2.92s/it]
Validation Error in epoch 4: 0.28560000658
Validation Error in epoch 5: 0.28020000458
  7%|▋         | 7/100 [00:20<04:38,  3.00s/it]
Validation Error in epoch 6: 0.27600002289
Validation Error in epoch 7: 0.27120000124
  9%|▉         | 9/100 [00:27<04:39,  3.07s/it]
Validation Error in epoch 8: 0.26399999857
 10%|█         | 10/100 [00:29<04:33,  3.03s/it]
Validation Error in epoch 9: 0.26260000467
 11%|█         | 11/100 [00:33<04:33,  3.07s/it]
Validation Error in epoch 10: 0.25800001621
Validation Error in epoch 11: 0.25580000877
 13%|█▎        | 13/100 [00:39<04:28,  3.09s/it]
Validation Error in epoch 12: 0.25300002098
 14%|█▍        | 14/100 [00:42<04:21,  3.04s/it]
Validation Error in epoch 13: 0.25180000067
Validation Error in epoch 14: 0.25139999390
 16%|█▌        | 16/100 [00:48<04:12,  3.00s/it]
Validation Error in epoch 15: 0.24779999256
 17%|█▋        | 17/100 [00:51<04:13,  3.05s/it]
Validation Error in epoch 16: 0.24879997969
Validation Error in epoch 17: 0.22780001163
 19%|█▉        | 19/100 [00:57<04:09,  3.08s/it]
Validation Error in epoch 18: 0.18279999495
 20%|██        | 20/100 [01:00<04:05,  3.07s/it]
Validation Error in epoch 19: 0.17159998417
Validation Error in epoch 20: 0.16820001602
 22%|██▏       | 22/100 [01:06<03:54,  3.01s/it]
Validation Error in epoch 21: 0.16119998693
 23%|██▎       | 23/100 [01:09<03:50,  2.99s/it]
Validation Error in epoch 22: 0.15920001268
Validation Error in epoch 23: 0.15700000525
 25%|██▌       | 25/100 [01:15<03:43,  2.98s/it]
Validation Error in epoch 24: 0.15799999237
 26%|██▌       | 26/100 [01:18<03:39,  2.96s/it]
Validation Error in epoch 25: 0.15340000391
 27%|██▋       | 27/100 [01:21<03:35,  2.95s/it]
Validation Error in epoch 26: 0.15380001068
Validation Error in epoch 27: 0.15179997683
 29%|██▉       | 29/100 [01:27<03:31,  2.98s/it]
Validation Error in epoch 28: 0.14999997616
 30%|███       | 30/100 [01:30<03:28,  2.98s/it]
Validation Error in epoch 29: 0.15020000935
Validation Error in epoch 30: 0.14800000191
 32%|███▏      | 32/100 [01:36<03:23,  2.99s/it]
Validation Error in epoch 31: 0.14859998226
 33%|███▎      | 33/100 [01:39<03:20,  2.99s/it]
Validation Error in epoch 32: 0.14840000868
Validation Error in epoch 33: 0.14780002832
 35%|███▌      | 35/100 [01:45<03:14,  2.99s/it]
Validation Error in epoch 34: 0.14679998159
 36%|███▌      | 36/100 [01:48<03:11,  2.99s/it]
Validation Error in epoch 35: 0.14539998770
Validation Error in epoch 36: 0.14459997416
 38%|███▊      | 38/100 [01:54<03:11,  3.09s/it]
Validation Error in epoch 37: 0.14480000734
 39%|███▉      | 39/100 [01:57<03:09,  3.10s/it]
Validation Error in epoch 38: 0.14539998770
Validation Error in epoch 39: 0.14459997416
 41%|████      | 41/100 [02:04<03:07,  3.18s/it]
Validation Error in epoch 40: 0.14380002022
 42%|████▏     | 42/100 [02:07<03:03,  3.17s/it]
Validation Error in epoch 41: 0.14380002022
 43%|████▎     | 43/100 [02:10<03:01,  3.18s/it]
Validation Error in epoch 42: 0.14319998026
Validation Error in epoch 43: 0.14219999313
 45%|████▌     | 45/100 [02:17<02:56,  3.21s/it]
Validation Error in epoch 44: 0.14259999990
 46%|████▌     | 46/100 [02:20<02:50,  3.16s/it]
Validation Error in epoch 45: 0.14259999990
Validation Error in epoch 46: 0.14240002632
 48%|████▊     | 48/100 [02:26<02:44,  3.16s/it]
Validation Error in epoch 47: 0.14200001955
 49%|████▉     | 49/100 [02:29<02:40,  3.14s/it]
Validation Error in epoch 48: 0.14060002565
Validation Error in epoch 49: 0.14160001278
 51%|█████     | 51/100 [02:35<02:28,  3.03s/it]
Validation Error in epoch 50: 0.14039999247
 52%|█████▏    | 52/100 [02:38<02:23,  2.98s/it]
Validation Error in epoch 51: 0.14160001278
Validation Error in epoch 52: 0.14060002565
 54%|█████▍    | 54/100 [02:44<02:19,  3.04s/it]
Validation Error in epoch 53: 0.13999998569
 55%|█████▌    | 55/100 [02:47<02:16,  3.02s/it]
Validation Error in epoch 54: 0.13859999180
Validation Error in epoch 55: 0.13959997892
 57%|█████▋    | 57/100 [02:53<02:09,  3.01s/it]
Validation Error in epoch 56: 0.13880002499
 58%|█████▊    | 58/100 [02:56<02:05,  2.98s/it]
Validation Error in epoch 57: 0.13819998503
Validation Error in epoch 58: 0.13940000534
 60%|██████    | 60/100 [03:02<02:02,  3.05s/it]
Validation Error in epoch 59: 0.13880002499
 61%|██████    | 61/100 [03:05<01:56,  3.00s/it]
Validation Error in epoch 60: 0.13840001822
Validation Error in epoch 61: 0.13779997826
 63%|██████▎   | 63/100 [03:11<01:49,  2.97s/it]
Validation Error in epoch 62: 0.13819998503
 64%|██████▍   | 64/100 [03:14<01:47,  2.98s/it]
Validation Error in epoch 63: 0.13859999180
Validation Error in epoch 64: 0.13739997149
 66%|██████▌   | 66/100 [03:20<01:40,  2.95s/it]
Validation Error in epoch 65: 0.13840001822
 67%|██████▋   | 67/100 [03:23<01:36,  2.94s/it]
Validation Error in epoch 66: 0.13639998436
Validation Error in epoch 67: 0.13700002432
 69%|██████▉   | 69/100 [03:29<01:30,  2.92s/it]
Validation Error in epoch 68: 0.13599997759
 70%|███████   | 70/100 [03:31<01:27,  2.90s/it]
Validation Error in epoch 69: 0.13760000467
Validation Error in epoch 70: 0.13739997149
 72%|███████▏  | 72/100 [03:37<01:21,  2.90s/it]
Validation Error in epoch 71: 0.13679999113
 73%|███████▎  | 73/100 [03:40<01:17,  2.88s/it]
Validation Error in epoch 72: 0.13660001755
 74%|███████▍  | 74/100 [03:43<01:14,  2.86s/it]
Validation Error in epoch 73: 0.13559997082
Validation Error in epoch 74: 0.13580000401
 76%|███████▌  | 76/100 [03:49<01:08,  2.86s/it]
Validation Error in epoch 75: 0.13620001078
 77%|███████▋  | 77/100 [03:52<01:05,  2.86s/it]
Validation Error in epoch 76: 0.13499999046
Validation Error in epoch 77: 0.13499999046
 79%|███████▉  | 79/100 [03:57<01:00,  2.88s/it]
Validation Error in epoch 78: 0.13580000401
 80%|████████  | 80/100 [04:00<00:57,  2.88s/it]
Validation Error in epoch 79: 0.13539999723
Validation Error in epoch 80: 0.13499999046
 82%|████████▏ | 82/100 [04:06<00:52,  2.90s/it]
Validation Error in epoch 81: 0.13539999723
 83%|████████▎ | 83/100 [04:09<00:49,  2.90s/it]
Validation Error in epoch 82: 0.13520002365
Validation Error in epoch 83: 0.13620001078
 85%|████████▌ | 85/100 [04:15<00:43,  2.88s/it]
Validation Error in epoch 84: 0.13539999723
 86%|████████▌ | 86/100 [04:18<00:40,  2.86s/it]
Validation Error in epoch 85: 0.13499999046
Validation Error in epoch 86: 0.13419997692
 88%|████████▊ | 88/100 [04:23<00:34,  2.88s/it]
Validation Error in epoch 87: 0.13559997082
 89%|████████▉ | 89/100 [04:26<00:31,  2.87s/it]
Validation Error in epoch 88: 0.13539999723
 90%|█████████ | 90/100 [04:29<00:28,  2.87s/it]
Validation Error in epoch 89: 0.13440001011
Validation Error in epoch 90: 0.13639998436
 92%|█████████▏| 92/100 [04:35<00:22,  2.87s/it]
Validation Error in epoch 91: 0.13419997692
 93%|█████████▎| 93/100 [04:38<00:20,  2.87s/it]
Validation Error in epoch 92: 0.13520002365
Validation Error in epoch 93: 0.13520002365
 95%|█████████▌| 95/100 [04:44<00:14,  2.88s/it]
Validation Error in epoch 94: 0.13499999046
 96%|█████████▌| 96/100 [04:46<00:11,  2.86s/it]
Validation Error in epoch 95: 0.13499999046
Validation Error in epoch 96: 0.13520002365
 97%|█████████▋| 97/100 [04:49<00:08,  2.90s/it]

In [ ]: